Master's thesis case study 3: Bandits with stopping¶
In [1]:
# Auto-reload edited project modules (adaptive_nof1) on every cell execution.
%load_ext autoreload
%autoreload 2
In [2]:
import numpy
import torch
from adaptive_nof1 import *
from adaptive_nof1.policies import *
from adaptive_nof1.helpers import *
from adaptive_nof1.inference import *
from adaptive_nof1.metrics import *
from matplotlib import pyplot as plt
import seaborn
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
In [3]:
# Setup generic n-of-1 parameters
block_length = 5  # observations per treatment block
max_length = 10 * block_length  # hard cap on study length (50 observations)
number_of_actions = 2  # number of treatments/arms compared
number_of_patients = 100  # simulated patients per configuration
In [4]:
# Scenarios
class NormalModel(Model):
    """Generating model: each intervention's outcome is an independent normal draw.

    NOTE(review): `numpy.random.Generator.normal` takes a *standard deviation*
    as its scale argument, so `variance` is effectively used as a std in
    `observe_outcome`; with the unit variances used in this notebook the two
    coincide — confirm if other variances are ever used.
    """

    def __init__(self, patient_id, mean, variance):
        # Seed the generator with the patient id so every simulated patient
        # is reproducible but distinct.
        self.rng = numpy.random.default_rng(patient_id)
        self.mean = mean
        self.variance = variance
        self.patient_id = patient_id

    def multivariate_normal_distribution(self, debug_data=None):
        """Torch multivariate normal with this model's mean and diagonal covariance.

        Bug fix: the original signature was `(debug_data)` without `self`, so
        calling it as a method raised NameError on `self.variance`. The unused
        `debug_data` parameter is kept (with a default) for compatibility.
        """
        cov = torch.diag_embed(torch.tensor(self.variance))
        return torch.distributions.MultivariateNormal(torch.tensor(self.mean), cov)

    def generate_context(self, history):
        # This model is context-free.
        return {}

    @property
    def additional_config(self):
        # Exposes the true arm means for metrics such as regret / KL divergence.
        return {"expectations_of_interventions": self.mean}

    @property
    def number_of_interventions(self):
        return len(self.mean)

    def observe_outcome(self, action, context):
        treatment_index = action["treatment"]
        return {"outcome": self.rng.normal(self.mean[treatment_index], self.variance[treatment_index])}

    def __str__(self):
        # Keep this format stable: `model_mapping` in the scoring cell keys on it.
        return f"NormalModel({self.mean, self.variance})"
def generating_scenario_I(patient_id):
    """Scenario I: no treatment effect — both arms are N(0, 1)."""
    return NormalModel(patient_id, mean=[0, 0], variance=[1, 1])


def generating_scenario_II(patient_id):
    """Scenario II: moderate effect — arm 0 is better by 1."""
    return NormalModel(patient_id, mean=[1, 0], variance=[1, 1])


def generating_scenario_III(patient_id):
    """Scenario III: large effect — arm 0 is better by 2."""
    return NormalModel(patient_id, mean=[2, 0], variance=[1, 1])
In [5]:
# Inference Model
def inference_model():
    """Fresh conjugate normal model: N(0, 1) prior on the mean, known unit variance."""
    return NormalKnownVariance(prior_mean=0, prior_variance=1, variance=1)

# Stopping Time
ALPHA_STOPPING = 0.01  # stop once P(some arm is best) exceeds 1 - alpha

def alpha_stopping_time(history, context):
    """Return True when the posterior identifies a best arm at level 1 - ALPHA_STOPPING.

    Re-fits the conjugate model on the full history each call, then checks
    whether the most probable maximiser has posterior probability above
    1 - ALPHA_STOPPING.
    """
    # Use the shared factory so the stopping rule and the policies always agree
    # on the inference model (the original duplicated the constructor here).
    model = inference_model()
    model.update_posterior(history, number_of_actions)
    probabilities = model.approximate_max_probabilities(number_of_actions, context)
    return 1 - max(probabilities) < ALPHA_STOPPING
In [6]:
# Policies
def make_stopping_policy(internal_policy):
    """Wrap `internal_policy` in a block design with the alpha stopping rule.

    Factored out because all four compared policies share exactly the same
    BlockPolicy + StoppingPolicy scaffolding.
    """
    return StoppingPolicy(
        policy=BlockPolicy(
            block_length=block_length,
            internal_policy=internal_policy,
        ),
        stopping_time=alpha_stopping_time,
    )

# Use the config constant instead of a hard-coded 2 (same value today).
fixed_policy = make_stopping_policy(
    FixedPolicy(
        number_of_actions=number_of_actions,
        inference_model=inference_model(),
    )
)
explore_then_commit = make_stopping_policy(
    ExploreThenCommit(
        number_of_actions=number_of_actions,
        exploration_length=4,
        block_length=block_length,
        inference_model=inference_model(),
    )
)
thompson_sampling_policy = make_stopping_policy(
    ThompsonSampling(
        inference_model=inference_model(),
        number_of_actions=number_of_actions,
    )
)
ucb_policy = make_stopping_policy(
    UpperConfidenceBound(
        inference_model=inference_model(),
        number_of_actions=number_of_actions,
        epsilon=0.05,
    )
)
In [7]:
# Full crossover study: every policy is run against every generating scenario.
scenario_factories = [
    generating_scenario_I,
    generating_scenario_II,
    generating_scenario_III,
]
study_designs = {
    "n_patients": [number_of_patients],
    "policy": [fixed_policy, explore_then_commit, thompson_sampling_policy, ucb_policy],
    "model_from_patient_id": scenario_factories,
}
configurations = generate_configuration_cross_product(study_designs)
In [8]:
# Run every configuration; each simulated patient is capped at max_length observations.
calculated_series, config_to_simulation_data = simulate_configurations(
    configurations, max_length
)
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
0%| | 0/50 [00:00<?, ?it/s]
[3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] 
[3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 1.95287569] [2.10230224 3.28970725] [1.68757613 3.28970725] [1.73281824 3.28970725] [1.89482357 3.28970725] [3.28970725 1.78263585] [2.11100749 3.28970725] [1.69909177 3.28970725] [3.28970725 0.7356268 ] [3.28970725 1.84953347] [3.28970725 1.48703369] [3.28970725 2.2207267 ] [2.60626885 3.28970725] [2.10139188 3.28970725] [1.06305465 3.28970725] [3.28970725 1.5903371 ] [2.57302619 3.28970725] [3.28970725 1.54313908] [1.92191388 3.28970725] [3.28970725 2.10131065] [3.28970725 2.41437005] [2.20606519 3.28970725] [3.28970725 1.40577887] [1.72323411 3.28970725] [1.91971212 3.28970725] [1.51226677 3.28970725] [1.76947968 3.28970725] [3.28970725 2.38250763] [1.66243489 3.28970725] [1.81691371 3.28970725] [2.06054612 3.28970725] [1.96420362 3.28970725] [2.19175514 3.28970725] [3.28970725 1.73492909] [2.3192803 3.28970725] [3.28970725 2.41342734] [1.78918069 3.28970725] [2.04259023 3.28970725] [2.4482715 3.28970725] [3.28970725 1.6629994 ] [3.28970725 1.37933528] [3.28970725 1.61955171] [3.28970725 1.75311485] [1.49172375 3.28970725] [3.28970725 2.55828373] [1.87732424 3.28970725] [3.28970725 1.91978472] [3.28970725 0.94912561] [2.38625532 3.28970725] [2.05114842 3.28970725] [3.28970725 2.07848305] [3.28970725 1.88759929] [3.28970725 1.8492947 ] [1.60664441 3.28970725] [2.13719975 3.28970725] [3.28970725 1.54683802] [3.28970725 2.3376584 ] [3.28970725 2.55935095] [3.28970725 1.12760385] [2.17969147 3.28970725] [1.39754359 3.28970725] [3.28970725 1.34065998] [1.20058969 3.28970725] [3.28970725 2.33910739] [1.89298032 3.28970725] [1.60545794 3.28970725] 
[3.28970725 1.43946271] [1.98447902 3.28970725] [3.28970725 2.71502871] [2.1064876 3.28970725] [3.28970725 2.04738365] [2.06403838 3.28970725] [1.33063265 3.28970725] [3.28970725 1.90824339] [3.28970725 1.93508305] [3.28970725 2.0257046 ] [2.54237517 3.28970725] [2.17975653 3.28970725] [1.74457456 3.28970725] [2.09638938 3.28970725] [1.51829405 3.28970725] [3.28970725 2.31133553] [2.30040541 3.28970725] [3.28970725 1.49083153] [1.17818228 3.28970725] [1.62489326 3.28970725] [1.69940162 3.28970725] [3.28970725 2.20661127] [1.5382256 3.28970725] [2.17364767 3.28970725] [2.15845385 3.28970725] [3.28970725 1.85660762] [0.82429786 3.28970725] [3.28970725 2.36626936] [3.28970725 2.13584086] [2.12926729 3.28970725] [3.28970725 1.49912721] [1.62173627 3.28970725] [1.80987343 3.28970725] [1.68533444 3.28970725] [2.02624912 1.95287569] [2.10230224 2.11053662] [1.73281824 1.91737436] [1.89482357 1.611254 ] [2.11163049 1.78263585] [2.11100749 2.64304136] [1.69909177 1.80166771] [1.74767096 1.84953347] [1.73330232 2.2207267 ] [2.10139188 2.30381042] [2.23221129 1.5903371 ] [1.59601251 1.54313908] [1.92191388 2.84914979] [1.86102076 2.10131065] [1.72323411 2.09355623] [1.91971212 1.74047112] [1.76947968 2.16002627] [1.76004928 2.38250763] [1.66243489 1.86662601] [1.81691371 1.90788863] [2.06054612 2.57566623] [1.96420362 2.17478056] [2.19175514 1.4523873 ] [2.3192803 1.88114517] [1.78918069 1.88086876] [2.4482715 2.36225855] [1.64481113 1.37933528] [1.47997932 1.61955171] [1.52559143 1.75311485] [1.87732424 2.09580938] [1.70446898 1.91978472] [1.40944019 0.94912561] [2.38625532 2.05004173] [2.05114842 1.23511877] [2.17641755 2.07848305] [1.64283545 1.88759929] [1.99808681 1.8492947 ] [2.13719975 2.17345956] [2.43327782 2.3376584 ] [1.20058969 1.25234479] [2.9694922 2.33910739] [1.89298032 1.84571926] [1.60545794 1.6712577 ] [2.05169945 1.43946271] [1.98447902 2.22525195] [2.1064876 2.09326514] [1.73038114 2.04738365] [2.06403838 1.67873992] [2.00469347 1.93508305] [2.17975653 
1.54663394] [2.09638938 1.72174411] [2.41144931 2.31133553] [1.62489326 1.89182652] [1.69940162 1.90144359] [2.05437533 2.20661127] [2.17364767 2.02901327] [2.15845385 2.52187589] [1.58988509 1.85660762] [1.67909647 2.13584086] [2.12926729 2.25337757] [1.62173627 1.40613098] [1.80987343 2.25089758] [1.68533444 2.00482809] [1.45546191 1.95287569] [2.10230224 1.82555816] [1.73281824 1.59974416] [1.88892068 1.78263585] [1.69909177 1.7296998 ] [1.74767096 2.00002304] [2.10139188 2.11807736] [1.53382306 1.54313908] [1.86102076 2.02638065] [1.72323411 1.92435447] [2.01659982 1.74047112] [1.66243489 1.93983713] [1.81691371 1.32841053] [2.06054612 2.30183742] [1.78918069 1.83713528] [2.20127104 2.36225855] [1.57372887 1.37933528] [1.47997932 1.65577113] [1.87732424 1.66195466] [1.70446898 1.63232217] [1.98065583 2.05004173] [1.50174701 2.07848305] [1.98978005 1.8492947 ] [2.13719975 1.79635951] [1.84616565 2.3376584 ] [1.20058969 1.32280047] [1.44770257 1.84571926] [1.98447902 1.7331086 ] [2.02162664 2.09326514] [1.73038114 1.98219429] [1.97301153 1.67873992] [2.08266903 1.93508305] [2.03758313 1.72174411] [2.14193374 2.31133553] [1.62489326 1.79524291] [1.69940162 1.68381888] [2.05437533 2.00918149] [1.8872172 2.02901327] [2.15845385 2.39983754] [1.67909647 1.67441897] [2.12926729 1.87552549] [1.5104976 1.40613098] [1.80987343 2.00936299] [1.45546191 1.80025815] [1.83137204 1.82555816] [1.74833863 1.59974416] [1.55662779 1.78263585] [1.69909177 1.4346049 ] [1.74767096 2.01881252] [2.10139188 1.89434984] [1.53382306 1.74301719] [1.86102076 1.86095275] [1.72323411 1.74201709] [1.63135659 1.32841053] [1.78918069 1.62235714] [1.54959548 1.37933528] [1.76059101 1.66195466] [1.78097328 1.63232217] [1.98065583 2.25689838] [1.84181141 1.8492947 ] [1.81677093 1.79635951] [1.91809598 1.7331086 ] [2.02162664 2.11153002] [1.84854473 1.67873992] [1.90262136 1.93508305] [2.14193374 2.15972344] [1.62489326 1.41831039] [1.72172626 2.00918149] [2.15845385 2.0209795 ] [1.41075384 
1.67441897] [1.92568072 1.87552549] [1.58256047 1.40613098] [1.80987343 1.91561085] [1.73072034 1.82555816] [1.55662779 1.68338809] [1.40932378 1.4346049 ] [1.70355914 1.86095275] [1.72323411 1.59888478] [1.52677934 1.62235714] [1.58164026 1.66195466] [1.75941664 1.63232217] [1.98065583 1.99357756] [1.84181141 1.83241122] [1.81810498 1.79635951] [1.82636405 1.7331086 ] [2.02162664 1.9906502 ] [1.83525043 1.67873992] [1.90262136 1.84632576] [2.14193374 1.93735343] [1.669262 1.41831039] [1.72172626 1.62149241] [1.92089956 2.0209795 ] [1.82186837 1.87552549] [1.80987343 1.97651927] [1.73072034 1.64064982] [1.55662779 1.53235852] [1.40932378 1.333351 ] [1.70355914 1.58722513] [1.99370799 1.59888478] [1.52677934 1.57759901] [1.58164026 1.47487419] [1.78926389 1.63232217] [1.98065583 1.97807017] [1.70095359 1.83241122] [1.7484532 1.7331086] [1.92583206 1.9906502 ] [1.7634656 1.67873992] [1.83170639 1.84632576] [1.95902034 1.93735343] [1.67747836 1.62149241] [1.92089956 1.98973009] [1.82186837 1.66252411] [1.45102069 1.53235852] [1.31203018 1.333351 ] [1.57912507 1.58722513] [1.36728217 1.47487419] [1.91416668 1.97807017] [1.63330179 1.7331086 ] [1.92583206 1.96123597] [1.7621056 1.67873992] [1.83170639 1.87612481] [1.92625841 1.93735343] [1.64559877 1.62149241] [1.92089956 1.76803282] [1.71593963 1.66252411] [1.45102069 1.38165296] [1.31203018 1.36821809] [1.57912507 1.47486883] [1.91416668 2.0787459 ] [1.92583206 1.73835468] [1.81095685 1.67873992] [1.83170639 1.92981191] [1.92625841 1.94926123] [1.68400988 1.62149241] [1.88739486 1.76803282] [1.69474064 1.66252411] [1.31203018 1.42522878] [1.63272255 1.47486883] [1.74353077 1.67873992] [1.83170639 1.8180798 ] [1.92625841 1.86139177] [1.65777542 1.62149241] [1.64666145 1.76803282] [1.75434398 1.66252411]
0%| | 0/50 [00:00<?, ?it/s]
[3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] 
[3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 1.95287569] [3.28970725 2.10230224] [3.28970725 1.68757613] [3.28970725 1.73281824] [2.7281569 3.28970725] [3.28970725 1.78263585] [2.94434083 3.28970725] [3.28970725 1.69909177] [3.28970725 0.7356268 ] [3.28970725 1.84953347] [2.32036703 3.28970725] [3.28970725 2.2207267 ] [3.43960218 3.28970725] [3.28970725 2.10139188] [1.89638799 3.28970725] [3.28970725 1.5903371 ] [3.40635952 3.28970725] [3.28970725 1.54313908] [2.75524721 3.28970725] [3.28970725 2.10131065] [3.24770338 3.28970725] [3.28970725 2.20606519] [3.28970725 1.40577887] [2.55656745 3.28970725] [2.75304545 3.28970725] [3.28970725 1.51226677] [2.60281301 3.28970725] [3.21584096 3.28970725] [3.28970725 1.66243489] [3.28970725 1.81691371] [2.89387946 3.28970725] [3.28970725 1.96420362] [3.02508847 3.28970725] [2.56826242 3.28970725] [3.28970725 2.3192803 ] [3.24676068 3.28970725] [2.62251402 3.28970725] [3.28970725 2.04259023] [3.28160483 3.28970725] [2.49633273 3.28970725] [3.28970725 1.37933528] [2.45288504 3.28970725] [2.58644818 3.28970725] [2.32505708 3.28970725] [3.28970725 2.55828373] [2.71065757 3.28970725] [2.75311805 3.28970725] [3.28970725 0.94912561] [3.21958866 3.28970725] [3.28970725 2.05114842] [3.28970725 2.07848305] [3.28970725 1.88759929] [3.28970725 1.8492947 ] [3.28970725 1.60664441] [2.97053309 3.28970725] [2.38017135 3.28970725] [3.17099173 3.28970725] [3.39268428 3.28970725] [1.96093718 3.28970725] [3.01302481 3.28970725] [3.28970725 1.39754359] [3.28970725 1.34065998] [2.03392302 3.28970725] [3.17244072 3.28970725] [2.72631365 3.28970725] [3.28970725 1.60545794] 
[2.27279604 3.28970725] [3.28970725 1.98447902] [3.54836204 3.28970725] [3.28970725 2.1064876 ] [2.88071698 3.28970725] [2.89737172 3.28970725] [2.16396599 3.28970725] [2.74157672 3.28970725] [2.76841638 3.28970725] [3.28970725 2.0257046 ] [3.28970725 2.54237517] [3.01308987 3.28970725] [3.28970725 1.74457456] [3.28970725 2.09638938] [2.35162738 3.28970725] [3.28970725 2.31133553] [3.28970725 2.30040541] [3.28970725 1.49083153] [2.01151561 3.28970725] [3.28970725 1.62489326] [3.28970725 1.69940162] [3.28970725 2.20661127] [2.37155893 3.28970725] [3.006981 3.28970725] [2.99178718 3.28970725] [3.28970725 1.85660762] [1.65763119 3.28970725] [3.19960269 3.28970725] [2.96917419 3.28970725] [2.96260062 3.28970725] [3.28970725 1.49912721] [2.45506961 3.28970725] [3.28970725 1.80987343] [2.51866778 3.28970725] [2.75070769 1.73281824] [2.94434083 2.64304136] [2.32036703 2.2276421 ] [2.87270141 3.28970725] [1.89638799 2.24126513] [2.84786216 3.28970725] [2.75524721 2.84914979] [2.69435409 2.10131065] [2.4064104 2.20606519] [2.89387946 2.57566623] [2.56826242 2.52380305] [2.71447851 2.3192803 ] [2.50883207 2.04259023] [2.49633273 2.33200857] [2.32505708 2.19181503] [2.59318286 2.55828373] [2.06845211 2.05114842] [2.38017135 2.43472707] [3.04272342 3.28970725] [1.96093718 1.8996458 ] [3.17244072 2.9694922 ] [2.27279604 2.05169945] [3.18584349 3.28970725] [2.16396599 2.24016872] [2.80707162 2.54237517] [2.55507744 2.09638938] [2.35162738 2.57830882] [2.44816857 2.30040541] [2.01151561 2.06731494] [2.37155893 2.21960381] [2.99178718 2.52187589] [1.65763119 1.89560817] [2.96260062 2.25337757] [2.63163201 2.64304136] [2.20083775 2.2276421 ] [2.75524721 2.43557852] [2.470283 2.20606519] [2.92995372 2.57566623] [2.42745839 2.52380305] [2.48647352 2.33200857] [2.3895279 2.19181503] [2.73213327 2.55828373] [2.38017135 2.25598459] [2.07922724 1.8996458 ] [2.88058391 2.9694922 ] [2.16396599 1.91148607] [2.76792949 2.54237517] [2.35162738 2.04195274] [2.51355842 2.30040541] [2.01151561 
1.70160283] [2.45618929 2.21960381] [1.65763119 1.78583052] [2.63163201 2.70935241] [2.42745839 2.04952202] [2.75726895 2.55828373] [2.45988615 2.25598459] [2.23967778 1.91148607] [2.25202635 2.04195274] [2.42429494 2.30040541] [2.20367775 1.70160283] [1.65763119 1.65821152] [2.37613044 2.04195274] [2.38163517 2.30040541] [1.65763119 1.68156906] [2.46982378 2.30040541] [1.65763119 1.76957808] [1.65763119 1.76896305] [1.65763119 1.72210251] [1.65763119 1.70304922]
0%| | 0/50 [00:00<?, ?it/s]
[3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] 
[3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.28970725 3.28970725] [3.61954236 3.28970725] [3.28970725 2.10230224] [3.28970725 1.68757613] [3.28970725 1.73281824] [3.28970725 1.89482357] [3.44930252 3.28970725] [3.77767416 3.28970725] [3.36575843 3.28970725] [2.40229346 3.28970725] [3.28970725 1.84953347] [3.15370036 3.28970725] [3.28970725 2.2207267 ] [4.27293551 3.28970725] [3.28970725 2.10139188] [2.72972132 3.28970725] [3.25700376 3.28970725] [3.28970725 2.57302619] [3.28970725 1.54313908] [3.28970725 1.92191388] [3.76797731 3.28970725] [3.28970725 2.41437005] [3.28970725 2.20606519] [3.28970725 1.40577887] [3.38990078 3.28970725] [3.28970725 1.91971212] [3.28970725 1.51226677] [3.28970725 1.76947968] [4.04917429 3.28970725] [3.32910156 3.28970725] [3.28970725 1.81691371] [3.72721279 3.28970725] [3.28970725 1.96420362] [3.85842181 3.28970725] [3.28970725 1.73492909] [3.28970725 2.3192803 ] [4.08009401 3.28970725] [3.45584735 3.28970725] [3.7092569 3.28970725] [3.28970725 2.4482715 ] [3.32966607 3.28970725] [3.28970725 1.37933528] [3.28621837 3.28970725] [3.28970725 1.75311485] [3.15839042 3.28970725] [4.22495039 3.28970725] [3.28970725 1.87732424] [3.58645139 3.28970725] [2.61579228 3.28970725] [4.05292199 3.28970725] [3.71781508 3.28970725] [3.74514972 3.28970725] [3.55426596 3.28970725] [3.28970725 1.8492947 ] [3.27331107 3.28970725] [3.80386642 3.28970725] [3.28970725 1.54683802] [3.28970725 2.3376584 ] [3.28970725 2.55935095] [3.28970725 1.12760385] [3.28970725 2.17969147] [3.28970725 1.39754359] [3.00732665 3.28970725] [2.86725635 3.28970725] [3.28970725 2.33910739] [3.55964698 3.28970725] [3.27212461 3.28970725] 
[3.10612937 3.28970725] [3.65114568 3.28970725] [4.38169537 3.28970725] [3.28970725 2.1064876 ] [3.28970725 2.04738365] [3.73070505 3.28970725] [2.99729932 3.28970725] [3.57491006 3.28970725] [3.60174972 3.28970725] [3.69237126 3.28970725] [4.20904183 3.28970725] [3.8464232 3.28970725] [3.28970725 1.74457456] [3.76305605 3.28970725] [3.18496072 3.28970725] [3.9780022 3.28970725] [3.96707207 3.28970725] [3.15749819 3.28970725] [3.28970725 1.17818228] [3.28970725 1.62489326] [3.36606829 3.28970725] [3.87327794 3.28970725] [3.20489226 3.28970725] [3.84031433 3.28970725] [3.28970725 2.15845385] [3.52327428 3.28970725] [2.49096453 3.28970725] [4.03293602 3.28970725] [3.28970725 2.13584086] [3.28970725 2.12926729] [3.28970725 1.49912721] [3.28970725 1.62173627] [3.28970725 1.80987343] [3.35200111 3.28970725] [3.68954924 3.28970725] [3.6432628 3.28970725] [4.11223508 3.28970725] [3.42862269 3.28970725] [2.72972132 2.24126513] [3.6803892 3.28970725] [3.19011963 2.41437005] [3.60100316 3.28970725] [3.44405983 3.28970725] [4.04796063 3.28970725] [3.50683158 3.28970725] [3.82959375 3.28970725] [3.52096268 3.28970725] [3.54716605 3.28970725] [3.69821278 3.28970725] [3.49598318 3.28970725] [3.93891592 3.28970725] [3.3116269 3.28970725] [3.8399724 3.28970725] [3.4448092 3.28970725] [3.27331107 2.73014367] [3.87038624 3.28970725] [3.5584082 3.28970725] [3.81533442 3.28970725] [3.56063296 3.28970725] [3.37582657 3.28970725] [3.66808653 3.28970725] [3.9645547 3.28970725] [3.55169415 3.28970725] [3.60173579 3.28970725] [3.18496072 2.57830882] [3.65470333 3.28970725] [3.48321491 3.28970725] [3.84329203 3.28970725] [3.81147803 3.28970725] [3.39902263 3.28970725] [2.49096453 1.89560817] [3.87329257 3.28970725] [3.53193345 3.28970725] [3.40235174 3.28970725] [3.63651471 3.28970725] [3.49572132 3.28970725] [3.76041277 3.28970725] [3.63859979 3.28970725] [3.62644441 3.28970725] [3.60768729 3.28970725] [3.76543712 3.28970725] [3.60336656 3.28970725] [3.53226252 3.28970725] [3.62834705 
3.28970725] [3.43079065 3.28970725] [3.68075754 3.28970725] [3.48127551 3.28970725] [3.62949303 3.28970725] [3.35683354 3.28970725] [3.60508509 3.28970725] [3.65536619 3.28970725] [3.35285869 3.28970725] [3.8268844 3.28970725] [3.8322275 3.28970725] [3.62759041 3.28970725] [3.71588574 3.28970725] [3.63511674 3.28970725] [3.46429438 3.28970725] [3.82109632 3.28970725] [3.7277351 3.28970725] [3.60416946 3.28970725] [3.92976818 3.28970725] [3.779973 3.28970725] [3.45348635 3.28970725] [3.44343584 3.28970725] [3.32659627 3.28970725] [3.69769815 3.28970725] [3.5677129 3.28970725] [3.69031652 3.28970725] [3.54772823 3.28970725] [3.69820156 3.28970725] [3.49538529 3.28970725] [3.54007351 3.28970725] [3.73237045 3.28970725] [3.53602584 3.28970725] [3.68440632 3.28970725] [3.38243491 3.28970725] [3.62701676 3.28970725] [3.50125698 3.28970725] [3.66064491 3.28970725] [3.63616419 3.28970725] [3.41070493 3.28970725] [3.75060632 3.28970725] [3.65900988 3.28970725] [3.62076431 3.28970725] [3.59008212 3.28970725] [3.53816401 3.28970725] [3.7684615 3.28970725] [3.70239419 3.28970725] [3.60848906 3.28970725] [3.78927997 3.28970725] [3.7449954 3.28970725] [3.56058446 3.28970725] [3.46667717 3.28970725] [3.27563587 3.28970725] [3.66248797 3.28970725] [3.48562607 3.28970725] [3.65344189 3.28970725] [3.56393786 3.28970725] [3.56075463 3.28970725] [3.44086845 3.28970725] [3.57416062 3.28970725] [3.63569958 3.28970725] [3.56986549 3.28970725] [3.79903671 3.28970725] [3.21643759 3.28970725] [3.52507872 3.28970725] [3.51364014 3.28970725] [3.65370431 3.28970725] [3.66626889 3.28970725] [3.46872736 3.28970725] [3.74853145 3.28970725] [3.71109201 3.28970725] [3.40397009 3.28970725] [3.52316917 3.28970725] [3.48181014 3.28970725] [3.65535185 3.28970725] [3.52281385 3.28970725] [3.63121402 3.28970725] [3.78902516 3.28970725] [3.51582971 3.28970725] [3.4236798 3.28970725] [3.50907003 3.28970725] [3.61992455 3.28970725] [3.72811233 3.28970725] [3.62575015 3.28970725] [3.58565335 3.28970725] 
[3.4398208 3.28970725] [3.57069886 3.28970725] [3.59009141 3.28970725] [3.62065736 3.28970725] [3.9843109 3.28970725] [3.66424871 3.28970725] [3.57887413 3.28970725] [3.62877434 3.28970725] [3.63331735 3.28970725] [3.51627643 3.28970725] [3.72534549 3.28970725] [3.69196438 3.28970725] [3.45556246 3.28970725] [3.5990571 3.28970725] [3.53961753 3.28970725] [3.64753884 3.28970725] [3.53205887 3.28970725] [3.61413761 3.28970725] [3.73555834 3.28970725] [3.51645764 3.28970725] [3.38054065 3.28970725] [3.48077838 3.28970725] [3.64564882 3.28970725] [3.66218852 3.28970725] [3.73363647 3.28970725] [3.6228635 3.28970725] [3.54646247 3.28970725] [3.54835296 3.28970725] [3.67278602 3.28970725] [3.59552661 3.28970725] [3.59628957 3.28970725] [3.561294 3.28970725] [3.56223149 3.28970725] [3.65022238 3.28970725] [3.49095276 3.28970725] [3.75539233 3.28970725] [3.69478159 3.28970725] [3.53063714 3.28970725] [3.66347642 3.28970725] [3.55641383 3.28970725] [3.67075326 3.28970725] [3.5336596 3.28970725] [3.47521144 3.28970725] [3.80589616 3.28970725] [3.57722147 3.28970725] [3.32186438 3.28970725] [3.42250935 3.28970725] [3.56810322 3.28970725] [3.66961415 3.28970725] [3.74515388 3.28970725] [3.72306898 3.28970725] [3.54170229 3.28970725] [3.55294894 3.28970725] [3.76134607 3.28970725] [3.64933276 3.28970725] [3.58736738 3.28970725] [3.71215351 3.28970725] [3.49298415 3.28970725] [3.70676046 3.28970725] [3.56411594 3.28970725] [3.79761408 3.28970725] [3.62537905 3.28970725] [3.52196116 3.28970725] [3.67184792 3.28970725] [3.57219259 3.28970725] [3.66301659 3.28970725] [3.57166898 3.28970725] [3.44974749 3.28970725] [3.82383064 3.28970725] [3.64210542 3.28970725] [3.3365258 3.28970725] [3.46400122 3.28970725] [3.62894845 3.28970725] [3.63267062 3.28970725] [3.69809931 3.28970725] [3.76253778 3.28970725] [3.57412259 3.28970725] [3.53180688 3.28970725] [3.70934412 3.28970725] [3.72404677 3.28970725] [3.58130391 3.28970725] [3.71727782 3.28970725] [3.46914003 3.28970725] [3.657989 
3.28970725] [3.66790936 3.28970725] [3.74883047 3.28970725] [3.59957392 3.28970725] [3.49797761 3.28970725] [3.65785301 3.28970725] [3.63599403 3.28970725] [3.79911661 3.28970725] [3.5662002 3.28970725] [3.51004322 3.28970725] [3.84401051 3.28970725]
In [10]:
# TODO: build the output table by selecting, per patient, the row at the maximum time index
def debug_data_to_torch_distribution(debug_data):
    """Posterior-predictive normal built from a debug-data record.

    Expects `debug_data` to carry per-arm posterior "mean" and "variance";
    the known observation variance of 1 is added to the posterior variance.
    """
    posterior_mean = torch.tensor(debug_data["mean"])
    # + the true observation variance of 1
    predictive_variance = numpy.array(debug_data["variance"]) + 1
    covariance = torch.diag_embed(torch.tensor(predictive_variance))
    return torch.distributions.MultivariateNormal(posterior_mean, covariance)
def data_to_true_distribution(data):
    """True outcome distribution of a simulated patient: N(true arm means, I)."""
    true_means = torch.tensor(data.additional_config["expectations_of_interventions"])
    identity_cov = torch.eye(true_means.shape[0])
    return torch.distributions.MultivariateNormal(true_means, identity_cov)
# Metrics evaluated for every simulation; the KL divergence compares the
# inference posterior against the true generating distribution.
metrics = [
    SimpleRegretWithMean(),
    BestArmIdentification(),
    CumulativeRegret(),
    Length(),
    KLDivergence(
        data_to_true_distribution=data_to_true_distribution,
        debug_data_to_posterior_distribution=debug_data_to_torch_distribution,
    ),
]

# Short labels for the verbose model / policy repr strings.
model_mapping = {
    "NormalModel(([0, 0], [1, 1]))": "I",
    "NormalModel(([1, 0], [1, 1]))": "II",
    "NormalModel(([2, 0], [1, 1]))": "III",
}
policy_mapping = {
    "StoppingPolicy(BlockPolicy(FixedPolicy))": "Fixed",
    "StoppingPolicy(BlockPolicy(ThompsonSampling(NormalKnownVariance(0, 1, 1))))": "TS",
    "StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))": "UCB",
    "StoppingPolicy(BlockPolicy(ExploreThenCommit(4,NormalKnownVariance(0, 1, 1))))": "ETC",
}

df = SeriesOfSimulationsData.score_data(
    [s["result"] for s in calculated_series],
    metrics,
    {"model": lambda x: model_mapping[x], "policy": lambda x: policy_mapping[x]},
)
df = df.reset_index(drop=True)

# Keep only the last time step of each (policy, metric, model, patient)
# trajectory, i.e. the metric value at the stopping time.
max_t_indices = df.groupby(["policy", "metric", "model", "patient_id"])["t"].idxmax()
filtered_df = df.iloc[max_t_indices]
filtered_df = filtered_df.reset_index(drop=True)

groupby_columns = ["model", "policy"]
pivoted_df = filtered_df.pivot(
    index=["model", "policy", "patient_id"],
    columns="metric",
    values="score",
)
table = pivoted_df.groupby(groupby_columns).agg(["mean", "std"])

# Presentation order for policies ("SH" is not in policy_mapping here and is
# simply ignored by the categorical sort).
policy_ordering = ["Fixed", "ETC", "SH", "UCB", "TS"]
# Convert the 'policy' column in the MultiIndex to a Categorical type with the specified order.
# NOTE(review): `pd` is assumed to be pandas, brought in by a wildcard import — verify.
table = table.reset_index()
table["policy"] = pd.Categorical(table["policy"], categories=policy_ordering, ordered=True)

# Sort by model, then by the ordered policy; select and rename the metric
# columns. Raw strings keep the LaTeX macros (\SIR, \BAI, \KLD) from being
# parsed as (invalid) string escape sequences.
sorted_table = table.sort_values(by=["model", "policy"]).set_index(groupby_columns)[
    [
        "Cumulative Regret (outcome)",
        "KL Divergence",
        "Simple Regret With Mean",
        "Length",
        "Best Arm Identification With Mean",
    ]
].rename(
    columns={
        "Cumulative Regret (outcome)": "Regret",
        "Simple Regret With Mean": r"$\SIR$",
        "Best Arm Identification With Mean": r"$\BAI$",
        "KL Divergence": r"$\KLD$",
    },
)
sorted_table.index.names = ["S.", "Policy"]
sorted_table
Out[10]:
| metric | Regret | $\KLD$ | $\SIR$ | Length | $\BAI$ | ||||||
|---|---|---|---|---|---|---|---|---|---|---|---|
| mean | std | mean | std | mean | std | mean | std | mean | std | ||
| S. | Policy | ||||||||||
| I | Fixed | -0.037432 | 3.597176 | 0.163983 | 0.165737 | 0.0 | 0.000000 | 16.52 | 8.818575 | 0.48 | 0.502117 |
| ETC | -0.292172 | 4.041548 | 0.166418 | 0.164108 | 0.0 | 0.000000 | 17.17 | 11.315543 | 0.48 | 0.502117 | |
| UCB | 0.309196 | 4.454360 | 0.17932 | 0.170011 | 0.0 | 0.000000 | 18.45 | 12.693254 | 0.46 | 0.500908 | |
| TS | -0.147157 | 4.328133 | 0.150667 | 0.159053 | 0.0 | 0.000000 | 22.42 | 14.843412 | 0.52 | 0.502117 | |
| II | Fixed | -6.732649 | 4.478590 | 0.219876 | 0.134231 | 0.04 | 0.196946 | 11.08 | 4.856902 | 0.96 | 0.196946 |
| ETC | -6.911719 | 4.506643 | 0.218885 | 0.135457 | 0.05 | 0.219043 | 11.27 | 4.921310 | 0.95 | 0.219043 | |
| UCB | -6.502219 | 5.637445 | 0.403294 | 0.284822 | 0.14 | 0.348735 | 11.6 | 6.539700 | 0.86 | 0.348735 | |
| TS | -12.266483 | 14.320192 | 0.411553 | 0.281109 | 0.22 | 0.416333 | 18.79 | 12.119185 | 0.78 | 0.416333 | |
| III | Fixed | -10.109516 | 2.809726 | 0.376703 | 0.399681 | 0.04 | 0.281411 | 7.43 | 1.289076 | 0.98 | 0.140705 |
| ETC | -10.243181 | 2.724291 | 0.354477 | 0.347126 | 0.02 | 0.200000 | 7.55 | 1.366075 | 0.99 | 0.100000 | |
| UCB | -33.577203 | 40.023253 | 1.111306 | 0.969029 | 0.38 | 0.788554 | 19.59 | 18.100781 | 0.81 | 0.394277 | |
| TS | -41.302591 | 42.305474 | 1.16179 | 1.023429 | 0.5 | 0.870388 | 25.07 | 18.014167 | 0.76 | 0.429235 | |
In [11]:
# Export the summary table as LaTeX (two-decimal precision, booktabs rules)
# for inclusion in the thesis.
with open('mt_resources/7-stopping/01-table.tex', 'w') as file:
    # Use a dedicated name instead of shadowing the builtin `str`.
    latex_table = sorted_table.style.format(precision=2).to_latex(hrules=True)
    print(latex_table)
    file.write(latex_table)
\begin{tabular}{lllrlrlrlrlr}
\toprule
& metric & \multicolumn{2}{r}{Regret} & \multicolumn{2}{r}{$\KLD$} & \multicolumn{2}{r}{$\SIR$} & \multicolumn{2}{r}{Length} & \multicolumn{2}{r}{$\BAI$} \\
& & mean & std & mean & std & mean & std & mean & std & mean & std \\
S. & Policy & & & & & & & & & & \\
\midrule
\multirow[c]{4}{*}{I} & Fixed & -0.04 & 3.60 & 0.16 & 0.17 & 0.00 & 0.00 & 16.52 & 8.82 & 0.48 & 0.50 \\
& ETC & -0.29 & 4.04 & 0.17 & 0.16 & 0.00 & 0.00 & 17.17 & 11.32 & 0.48 & 0.50 \\
& UCB & 0.31 & 4.45 & 0.18 & 0.17 & 0.00 & 0.00 & 18.45 & 12.69 & 0.46 & 0.50 \\
& TS & -0.15 & 4.33 & 0.15 & 0.16 & 0.00 & 0.00 & 22.42 & 14.84 & 0.52 & 0.50 \\
\multirow[c]{4}{*}{II} & Fixed & -6.73 & 4.48 & 0.22 & 0.13 & 0.04 & 0.20 & 11.08 & 4.86 & 0.96 & 0.20 \\
& ETC & -6.91 & 4.51 & 0.22 & 0.14 & 0.05 & 0.22 & 11.27 & 4.92 & 0.95 & 0.22 \\
& UCB & -6.50 & 5.64 & 0.40 & 0.28 & 0.14 & 0.35 & 11.60 & 6.54 & 0.86 & 0.35 \\
& TS & -12.27 & 14.32 & 0.41 & 0.28 & 0.22 & 0.42 & 18.79 & 12.12 & 0.78 & 0.42 \\
\multirow[c]{4}{*}{III} & Fixed & -10.11 & 2.81 & 0.38 & 0.40 & 0.04 & 0.28 & 7.43 & 1.29 & 0.98 & 0.14 \\
& ETC & -10.24 & 2.72 & 0.35 & 0.35 & 0.02 & 0.20 & 7.55 & 1.37 & 0.99 & 0.10 \\
& UCB & -33.58 & 40.02 & 1.11 & 0.97 & 0.38 & 0.79 & 19.59 & 18.10 & 0.81 & 0.39 \\
& TS & -41.30 & 42.31 & 1.16 & 1.02 & 0.50 & 0.87 & 25.07 & 18.01 & 0.76 & 0.43 \\
\bottomrule
\end{tabular}
In [12]:
def rename_df(df):
    """Add a column with short policy labels for plotting.

    NOTE(review): the '#'-separated column name appears to be a convention
    understood by plot_lines — verify against its implementation.
    """
    df["policy_#_metric_#_model_p"] = df["policy"].map(lambda name: policy_mapping[name])
    return df
# Cumulative regret over time for scenario II (model "NormalModel(([1, 0], [1, 1]))"),
# with short policy names supplied by rename_df.
SeriesOfSimulationsData.plot_lines(
[s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
[
CumulativeRegret(),
],
legend_position=(0.02,0.3),
process_df = rename_df,
)
plt.ylabel('Regret')
plt.savefig("mt_resources/7-stopping/01_cumulative_regret.pdf", bbox_inches="tight")
In [13]:
# Simple regret over time for scenario II, same filtering and labeling as
# the cumulative-regret figure.
SeriesOfSimulationsData.plot_lines(
[s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
[
SimpleRegretWithMean(),
],
legend_position=(0.8,1.0),
process_df = rename_df,
)
plt.ylabel('Simple Regret')
plt.savefig("mt_resources/7-stopping/01_simple_regret.pdf", bbox_inches="tight")
In [14]:
# KL divergence between the inference posterior and the true distribution
# over time, scenario II only.
SeriesOfSimulationsData.plot_lines(
[s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
[
KLDivergence(data_to_true_distribution = data_to_true_distribution, debug_data_to_posterior_distribution=debug_data_to_torch_distribution)
],
legend_position=(0.8,1.0),
process_df = rename_df,
)
plt.ylabel('KL Divergence')
plt.savefig("mt_resources/7-stopping/01-kl-divergence.pdf", bbox_inches="tight")
In [15]:
# Plot, per policy, how many scenario-II patients have stopped by each
# time step. (Leftover commented-out estimator/errorbar experiments removed.)
df = SeriesOfSimulationsData.score_data(
    [s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
    [IsStopped()],
)
df["policy"] = df["policy"].apply(lambda x: policy_mapping[x])
# Summing IsStopped scores over patients gives the count stopped at each t
# (assumes IsStopped scores are 0/1 per patient — verify in the metric).
groupby_df_sum = df.groupby(["policy", "model", "t"]).sum()
ax = seaborn.lineplot(
    data=groupby_df_sum,
    x="t",
    y="score",
    hue="policy",
)
plt.ylabel("Number of patients")
seaborn.move_legend(ax, "upper right", title=None)
plt.savefig("mt_resources/7-stopping/01_is_stopped.pdf", bbox_inches="tight")
In [16]:
# Interactive allocation plots for every (policy, model) configuration.
plot_allocations_for_calculated_series(calculated_series)
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_drag' property; using the latest value layout_plot = gridplot( /opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_scroll' property; using the latest value layout_plot = gridplot(
Out[16]:
In [17]:
# Allocations for the UCB policy in scenario I (equal means [0, 0]).
plot_allocations_for_calculated_series([s for s in calculated_series if s["configuration"]["policy"] == "StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))" and s["configuration"]["model"] == "NormalModel(([0, 0], [1, 1]))"])
Out[17]:
:Layout
In [18]:
# Allocations for the UCB policy in scenario II (means [1, 0]).
plot_allocations_for_calculated_series([s for s in calculated_series if s["configuration"]["policy"] == "StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))" and s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"])
Out[18]:
:Layout
In [21]:
import param
from functools import reduce


class PatientExplorer(param.Parameterized):
    """Widget state for interactively browsing per-patient debug data.

    Sliders select a configuration from `calculated_series` and a patient;
    `hvplot` re-renders whenever either of those parameters changes.
    """

    patient_id = param.Integer(default=0, bounds=(0, number_of_patients - 1))
    configuration = param.Integer(default=0, bounds=(0, len(calculated_series) - 1))
    t = param.Integer(default=10, bounds=(0, max_length - 1))

    @param.depends("patient_id", "configuration")
    def hvplot(self):
        """Line plot of the flattened debug data of the selected patient.

        NOTE(review): `pandas` is assumed to come from a wildcard import — verify.
        """
        simulation = (
            calculated_series[self.configuration]["result"].simulations[self.patient_id]
        )
        flattened_records = [
            flatten_dictionary(entry) for entry in simulation.history.debug_data()
        ]
        return pandas.DataFrame(flattened_records).hvplot()

    @param.depends("configuration")
    def configuration_name(self):
        """Panel pane displaying the selected configuration dict."""
        return panel.panel(calculated_series[self.configuration]["configuration"])
In [22]:
# Assemble the interactive dashboard: parameter sliders and configuration
# text on the left, the dynamically updated plot on the right.
import holoviews
import panel
explorer = PatientExplorer()
hvplot = holoviews.DynamicMap(explorer.hvplot)
panel.Column(
panel.Row(panel.Column(explorer.param, explorer.configuration_name), hvplot),
)
Out[22]:
In [ ]: